---
title: "pedagogical plot"
format: html
editor: visual
---

## Pedagogical plot

```{r}
# Package setup.
# Reinstalling on every render is slow and can silently change package
# versions between runs, so packages are only installed when missing.
# NOTE(review): the original always rebuilt lme4 from source and
# force-reinstalled dsl from the "update-Aug2024" branch; run those
# installs by hand when a fresh build is actually needed.
library(devtools)
if (!requireNamespace("lme4", quietly = TRUE)) {
  install.packages("lme4", type = "source")
}
if (!requireNamespace("dsl", quietly = TRUE)) {
  install_github("naoki-egami/dsl", ref = "update-Aug2024", dependencies = TRUE)
}
library(dsl)
library(janitor)
library(tidyverse)
library(broom)
library(jtools)
library(ggplot2)
library(broom.mixed)
library(estimatr)
```

```{r}
# Draw one bootstrap replicate of the protest data in which only N rows keep
# their human label.
#
# won_protest_machine: data frame of protest images (one row per image).
# N:     number of rows whose `violence_truth` stays non-NA; must lie in
#        [0, nrow(won_protest_machine)].
# seed:  RNG seed, set with set.seed() so the replicate is reproducible
#        (this also fixes the RNG state for whatever the caller runs next).
# truth: human-coded outcome, aligned by row position with
#        `won_protest_machine`. Defaults to the global
#        `won_protest_human$violence`, as in the original implementation.
#
# Returns the resampled data frame with a `row` index (1..n) and a
# `violence_truth` column that is NA for the n - N unlabeled rows.
won_sampling <- function(won_protest_machine, N, seed,
                         truth = won_protest_human$violence) {
  n_rows <- nrow(won_protest_machine)
  stopifnot(length(N) == 1, N >= 0, N <= n_rows)

  won_protest_machine$violence_truth <- truth
  set.seed(seed)

  # Nonparametric bootstrap: resample rows with replacement.
  # drop = FALSE keeps a data frame even if only one column is present.
  boot_idx <- sample(n_rows, size = n_rows, replace = TRUE)
  sample_won <- won_protest_machine[boot_idx, , drop = FALSE]
  sample_won$row <- seq_len(n_rows)

  # Choose the n - N rows that LOSE their label; the remaining N stay labeled.
  unlabeled <- sample(n_rows, size = n_rows - N)
  sample_won$violence_truth[unlabeled] <- NA
  sample_won
}
```

```{r}
# Load the machine predictions and the human annotations for the
# Won et al. (2017) protest images.
won_machine_df <- read.csv(file = "~/Downloads/DSL images/WonetAl_result.csv")
won_human_df <- read.table(
  file = "~/Downloads/DSL images/annot_test.txt",
  sep = "\t", quote = "", comment.char = ""
)
# read.table() is called without `header = TRUE`, so the column names sit in
# the first row; promote them to names.
# NOTE(review): the columns are then probably still character (e.g. "1");
# the `== 1` comparisons below work either way.
won_human_df <- won_human_df %>% row_to_names(row_number = 1)

# The two files are matched by row position, not by an image id, so they
# must have the same number of rows in the same order.
stopifnot(nrow(won_machine_df) == nrow(won_human_df))
won_machine_df["protest_h"] <- won_human_df["protest"]

# Keep only images the human coders marked as protests.
won_protest_human <- subset(won_human_df, protest == 1)
won_protest_machine <- subset(won_machine_df, protest_h == 1)

# Copy the human-coded covariates onto the machine rows (again by position).
human_covariates <- c(
  "sign", "photo", "fire", "police", "children",
  "group_20", "flag", "night", "shouting"
)
won_protest_machine[human_covariates] <- won_protest_human[human_covariates]
```

```{r}
# Monte Carlo comparison of three estimators of the `police` coefficient.
# Each of the k replicates draws a bootstrap sample in which only N rows keep
# their human label (see won_sampling()), then fits:
#   * DSL        - dsl() on the labeled `violence_truth` plus the machine
#                  prediction `violence`;
#   * Sub-sample - lm_robust() on `violence_truth`; lm_robust's default NA
#                  handling drops the unlabeled rows, leaving the N labeled;
#   * surrogate  - lm_robust() on the machine prediction `violence`, all rows.
k <- 40  # number of Monte Carlo replicates
set.seed(676677)
seeds <- sample(x = 1:99999, size = k)

N <- 200  # rows that keep their human label in each replicate

# Coefficient tables are read by position: with an intercept, row 5 is
# `police` (terms: (Intercept), sign, photo, fire, police, ...).
police_row <- 5L
stat_names <- c("estimate", "std.error", "conf.low", "conf.high", "p.value")
truth_formula <- violence_truth ~ sign + photo + fire + police + children +
  group_20 + flag + night + shouting
surrogate_formula <- violence ~ sign + photo + fire + police + children +
  group_20 + flag + night + shouting

# Empty results table: one row per statistic, one column (X1..Xk) per replicate.
new_runs_frame <- function() {
  data.frame(matrix(nrow = length(stat_names), ncol = k), row.names = stat_names)
}

# Pull the `police` row of an lm_robust fit as a numeric vector in
# `stat_names` order; stops if the formula's term order no longer puts
# `police` at `police_row`.
police_stats <- function(fit) {
  d <- as.data.frame(tidy(fit, conf.int = TRUE))
  stopifnot(startsWith(d$term[police_row], "police"))
  unlist(d[police_row, stat_names], use.names = FALSE)
}

multiple_runs_police_200 <- new_runs_frame()   # DSL
smultiple_runs_police_200 <- new_runs_frame()  # Sub-sample
rmultiple_runs_police_200 <- new_runs_frame()  # surrogate

for (i in seq_len(k)) {
  df <- won_sampling(won_protest_machine, N, seeds[i])

  dsl_model_i <- dsl(
    model = "lm",
    formula = truth_formula,
    predicted_var = "violence_truth",
    prediction = "violence",
    data = df,
    tuning = TRUE
  )
  # NOTE(review): row `police_row` of the DSL summary is assumed to be
  # `police`, as for lm_robust; confirm against summary(dsl_model_i).
  d <- data.frame(summary(dsl_model_i))
  multiple_runs_police_200[, i] <- c(
    d$Estimate[police_row], d$Std..Error[police_row], d$CI.Lower[police_row],
    d$CI.Upper[police_row], d$p.value[police_row]
  )

  sub_model_i <- lm_robust(formula = truth_formula, data = df)
  smultiple_runs_police_200[, i] <- police_stats(sub_model_i)

  sur_model_i <- lm_robust(formula = surrogate_formula, data = df)
  rmultiple_runs_police_200[, i] <- police_stats(sur_model_i)

  print(i)
}
```

```{r}
# Turn a statistics-by-replicate table into one row per replicate (row names
# X1..Xk) and tag every row with the estimator that produced it.
runs_to_rows <- function(runs, label) {
  rows <- as.data.frame(t(runs))
  rows$model <- label
  rows
}
dsl_50 <- runs_to_rows(multiple_runs_police_200, "DSL")
sub_50 <- runs_to_rows(smultiple_runs_police_200, "Sub-sample")
resnet_50 <- runs_to_rows(rmultiple_runs_police_200, "ResNet")

# Benchmark: the same regression fit on every human-labeled row.
won_protest_machine$violence_truth <- won_protest_human$violence
truth_model <- lm_robust(
  formula = violence_truth ~ sign + photo + fire + police + children +
    group_20 + flag + night + shouting,
  data = won_protest_machine
)
truth_tidy <- tidy(truth_model, conf.int = TRUE)
# Row 5 is the `police` term.
t_police <- truth_tidy$estimate[5]
t_l_police <- truth_tidy$conf.low[5]
t_up_police <- truth_tidy$conf.high[5]

benchmark_50 <- data.frame(
  estimate = t_police,
  std.error = NA,
  conf.low = t_l_police,
  conf.high = t_up_police,
  p.value = NA,
  model = "Benchmark",
  row.names = "X0"
)

# Stack all estimators, keeping each row's replicate id in `sample`.
with_sample_id <- function(x) cbind(sample = rownames(x), x)
df_combined <- do.call(
  rbind,
  lapply(list(dsl_50, sub_50, resnet_50, benchmark_50), with_sample_id)
)

# Factor levels fix legend order. The sample levels are only the replicate
# ids X1..Xk, so the benchmark's "X0" becomes NA.
df_combined$model <- factor(
  df_combined$model,
  levels = c("DSL", "Sub-sample", "ResNet", "Benchmark")
)
df_combined$sample <- factor(df_combined$sample, levels = unique(rownames(dsl_50)))
```

```{r}
# Stack replicates 10 units apart on the y axis and offset the three
# estimators within each replicate so their intervals do not overlap.
df_combined <- df_combined %>%
  mutate(
    yposition = as.numeric(sample) * 10 + case_when(
      model == "DSL" ~ -1.5,        # DSL slightly below
      model == "Sub-sample" ~ 0,    # Sub-sample centered
      model == "ResNet" ~ 1.5       # ResNet slightly above
    )
  )

# TRUE when an interval covers the benchmark (full-data) estimate.
df_combined$contains_truth <- df_combined$conf.low <= t_police &
  df_combined$conf.high >= t_police

# The benchmark row has no replicate (NA `sample`), so its position above is
# NA; place it at the bottom. Selecting by model instead of a fixed row index
# keeps this correct for any number of replicates.
df_combined$yposition[df_combined$model == "Benchmark"] <- 1

# Plot with increased y-axis distance and shapes for models
# Interval plot: one horizontal CI per estimator per replicate; the dashed
# line marks the benchmark (full-data) estimate.
ggplot(df_combined, aes(y = yposition, x = estimate, color = model, shape = model)) +
  geom_vline(xintercept = t_police, color = "black", linetype = "dashed", linewidth = 0.4) +
  geom_errorbarh(aes(xmin = conf.low, xmax = conf.high), height = 0.4, linewidth = 0.8) +
  geom_point(size = 2) +
  # Breaks use the same 10-unit spacing as `yposition`.
  scale_y_continuous(
    breaks = 10 * seq_along(unique(df_combined$sample)),
    labels = unique(df_combined$sample)
  ) +
  scale_color_manual(values = c("DSL" = "tomato2", "Sub-sample" = "gold2", "ResNet" = "deepskyblue", "Benchmark" = "black")) +
  scale_shape_manual(values = c("DSL" = 16, "Sub-sample" = 17, "ResNet" = 18, "Benchmark" = 15)) +
  labs(title = "Confidence intervals for variable \"police\" (Won et al, 2017)",
       x = "Estimate",
       y = NULL,
       color = "Model",
       shape = "Model") +
  theme_bw() +
  theme(axis.text.y = element_blank(),       # hide per-replicate labels
        axis.ticks.y = element_blank(),
        legend.position = "bottom",
        panel.grid.major.y = element_blank(), # remove horizontal grid lines
        panel.grid.minor.y = element_blank())

```

```{r}
# Same interval plot, but intervals that miss the benchmark estimate are
# faded so coverage failures stand out.
ggplot(df_combined, aes(y = yposition, x = estimate, color = model, shape = model)) +
  # Dashed vertical line at the benchmark estimate
  geom_vline(xintercept = t_police, color = "black", linetype = "dashed", linewidth = 0.4) +

  # Error bars and points, faded when the interval misses the benchmark
  geom_errorbarh(
    aes(xmin = conf.low, xmax = conf.high, alpha = contains_truth),
    height = 0.4,
    linewidth = 0.8
  ) +
  geom_point(aes(alpha = contains_truth), size = 2) +

  # Breaks use the same 10-unit spacing as `yposition`
  scale_y_continuous(
    breaks = 10 * seq_along(unique(df_combined$sample)),
    labels = unique(df_combined$sample)
  ) +

  scale_color_manual(values = c("DSL" = "tomato2", "Sub-sample" = "gold2", "ResNet" = "deepskyblue", "Benchmark" = "black")) +
  scale_shape_manual(values = c("DSL" = 16, "Sub-sample" = 17, "ResNet" = 18, "Benchmark" = 15)) +

  # Fixed alpha per value. scale_alpha_discrete(range = ...) spreads the range
  # over the levels actually present, so when every interval covers the
  # benchmark (only TRUE present) all bars would be drawn at 0.3.
  scale_alpha_manual(values = c("FALSE" = 0.3, "TRUE" = 1), guide = "none") +

  labs(title = "Confidence intervals for variable \"police\" (Won et al., 2017)",
       x = "Estimate",
       y = NULL,
       color = "Model",
       shape = "Model") +
  theme_bw() +
  theme(axis.text.y = element_blank(),
        plot.title = element_text(size = 14),
        axis.title = element_text(size = 13),
        axis.ticks.y = element_blank(),
        legend.position = "bottom",
        panel.grid.major.y = element_blank(),
        panel.grid.minor.y = element_blank(),
        axis.text = element_text(size = 13),
        legend.title = element_text(size = 13),
        legend.text = element_text(size = 13))
```
